ENH: cov: expose correction and weights parameters #690
bruAristimunha wants to merge 13 commits into data-apis:main
Conversation
Resolves data-apis#688. Adds `axis`, `correction`, `frequency_weights`, and `weights` to `cov`, giving users control over the degrees-of-freedom correction and the observation-axis / weighted variants that `numpy.cov` and `torch.cov` already support. Naming follows array-api conventions (`axis`, `correction`) rather than numpy's (`rowvar`, `bias`, `ddof`); the docstring includes a one-to-one mapping. The delegation moves observations to the last axis via `xp.moveaxis`, collapsing `rowvar` out of the backend dispatch — only `ddof` vs `correction` differs between branches. Dask's native `cov` forces `.compute()` on a lazy scalar when any weights are given, so weighted dask inputs fall through to the generic implementation, which is fully lazy.
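The `moveaxis` trick described above can be illustrated with a minimal sketch in plain numpy (this is not the PR's actual code; `cov_delegate` is a made-up name for illustration):

```python
import numpy as np

def cov_delegate(m, axis=-1, correction=1):
    # Normalise the observation axis up front so the backend call
    # never needs numpy's rowvar flag: after moveaxis, observations
    # always sit on the last axis and variables on the first.
    m = np.moveaxis(m, axis, -1)
    # numpy spells the correction parameter as ddof; a torch branch
    # would pass correction= instead -- that is the only difference.
    return np.cov(m, ddof=correction)

x = np.arange(12.0).reshape(4, 3)  # 4 observations of 3 variables
assert np.allclose(cov_delegate(x, axis=0), np.cov(x, rowvar=False))
```

With the axis normalised once, every backend branch sees the same layout, which is why `rowvar` never needs to appear in the dispatch logic.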
|
The PR description mentions that other functions in this library already use |
|
Hey @betatim! This was a somewhat hard decision to make, but I can follow numpy more strictly if you prefer. I basically looked at what was already implemented in the array API and how existing functions handle the parameter names I was trying to introduce. For each parameter, I checked how the mapping from numpy had been done in the past. For `correction`, see https://data-apis.org/array-api/latest/API_specification/generated/array_api.var.html: there was a discussion on using `correction` instead of `bias`+`ddof` for these functions. It was introduced in data-apis/array-api#10, and later there were some interesting discussions in data-apis/array-api#695; it was @kgryte who led the discussion. As for `frequency_weights` and `weights`, it was my experience with pyRiemann that drove the decision. The only place I remember using something similar is statsmodels (`freq_weights`, `var_weights`): https://www.statsmodels.org/stable/generated/statsmodels.genmod.generalized_linear_model.GLM.html#statsmodels.genmod.generalized_linear_model.GLM.freq_weights I think in scikit-learn you use `sample_weight` more, but I can accommodate any request about this. |
betatim
left a comment
What is your thinking on validating the weights passed in? Things like checking the shapes make sense, that they are all positive (is this actually required? how does it fit with being lazy?)
|
I liked this idea a lot @betatim! I think it will make the checks in libraries that use array-api-extra much lighter. |
|
FYI @qbarthelemy and @agramfort |
Co-authored-by: Quentin Barthélemy <q.barthelemy@gmail.com>
Co-authored-by: Quentin Barthélemy <q.barthelemy@gmail.com>
|
Thanks a lot for the detailed answer in #690 (comment) - I didn't realise there was precedent for using these names. What is the "temporary deployed" thing that keeps happening? |
|
It is not me @betatim, I think it is something that @lucascolley is pushing here: #699 |
|
Happy that you liked the response @betatim :) I think I addressed all the points from you and @qbarthelemy. Can we merge? |
fixed in bd3652a |
lucascolley
left a comment
I took an initial look, seems pretty good!
One high level comment @bruAristimunha — could you demonstrate that this works as expected when used in a branch of sklearn? You should be able to change https://github.com/scikit-learn/scikit-learn/blob/06aded051fe6c7c9970b7e13c3669f952a799831/maint_tools/vendor_array_api_extra.sh#L8-L9 to point to this branch and commit hash.
|
Hey @betatim, as you have the first covariance PR on scikit-learn, can you help with this small test as requested by @lucascolley?
|
|
Hey @lucascolley, I made it in my branch, which was built on top of @betatim's work for the first scikit-learn covariance PR; you can check it here: scikit-learn/scikit-learn#33600 |
thanks! Would be great if you could take a look, Tim |
Co-authored-by: Quentin Barthélemy <q.barthelemy@gmail.com>
Addresses review feedback (kgryte, betatim) that the motivation for allowing non-integer correction was not obvious from the docstring: weighted unbiased correction and autocorrelated data both require fractional values.
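One concrete way to see why fractional corrections arise, sketched in plain numpy rather than this PR's API (the name `ddof_eff` is mine): with reliability weights, the unbiased estimator's effective degrees-of-freedom correction is `v2 / v1`, which is generally non-integer.

```python
import numpy as np

rng = np.random.default_rng(0)
m = rng.standard_normal((2, 8))   # 2 variables, 8 observations
aw = rng.random(8) + 0.5          # reliability (aweights-style) weights

v1 = aw.sum()
v2 = (aw ** 2).sum()
ddof_eff = v2 / v1                # effective correction, fractional in general

avg = (m * aw).sum(axis=1) / v1
m_c = m - avg[:, None]
c = (m_c * aw) @ m_c.T / (v1 - ddof_eff)  # unbiased weighted covariance

# numpy reaches the same denominator through aweights + integer ddof=1
assert np.allclose(c, np.cov(m, aweights=aw, ddof=1))
```

An integer-only `correction` could never express `ddof_eff` directly, which is the docstring motivation the review asked for.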
Adds tests for the 1-D shape and length checks in the generic cov path. Raises the diff coverage for this PR from 93.33% to 100%.
|
hey @lucascolley, I was wondering, can you please approve the CI for the final test? |
Resolves #688.
Summary
- Adds `axis`, `correction`, `frequency_weights`, and `weights` parameters to `xpx.cov`, unlocking the degrees-of-freedom and weighted variants that `numpy.cov` and `torch.cov` already support.
- Parameter names follow the conventions (`axis`, `correction`) used elsewhere in this library rather than numpy's (`rowvar`, `bias`, `ddof`). The docstring includes a one-to-one mapping for users migrating from `numpy.cov`.
Design
- The delegation moves observations to the last axis via `xp.moveaxis`, which collapses `rowvar` out of backend dispatch entirely — only `ddof` (numpy/cupy/dask/jax) vs `correction` (torch) differs between branches.
- Fallbacks to the generic implementation (`_funcs.cov`):
  - `m.ndim > 2` (batched input, not supported by any native).
  - Non-integer `correction` (rejected by `numpy.cov`'s `ddof`).
  - Weighted dask input: `dask.array.cov` forces `.compute()` on a lazy 0-D scalar via its internal `if fact <= 0` check. The generic path stays fully lazy because its weighted branch doesn't compare `fact` to zero (noted in docstring).
- Weighted formula in `_funcs.cov` matches numpy's (algebraically): `c = (m_c · w) @ m_c.T / (v1 - correction · v2 / v1)`.
Tests
New `TestCov` cases validate against `np.cov` as reference:
- `test_correction` (integer ddof)
- `test_correction_float` (generic-path-only, hand-computed reference)
- `test_axis` / `test_axis_with_weights` / `test_axis_out_of_bounds`
- `test_frequency_weights` / `test_weights` / `test_both_weights`
- `test_batch_with_weights`
Test plan
- `pytest tests/test_funcs.py::TestCov` — 126 passed across numpy, torch, jax, dask, array-api-strict
- `pytest tests/test_funcs.py` full — 4263 passed, 0 failed
- `lefthook run pre-commit` — ruff, numpydoc, mypy, pyright, typos all green
- `lazy_xp_function(cov)` asserts 0 `.compute()` calls; holds for weighted path via the fallback
pytest tests/test_funcs.py::TestCov— 126 passed across numpy, torch, jax, dask, array-api-strictpytest tests/test_funcs.pyfull — 4263 passed, 0 failedlefthook run pre-commit— ruff, numpydoc, mypy, pyright, typos all greenlazy_xp_function(cov)asserts 0.compute()calls, holds for weighted path via the fallback